import os
import pickle
import numpy as np
import pandas as pd
import time
from causallearnmain.causallearn.search.FCMBased import lingam
import sys
from utilis.config import ARGConfig
from metaworld_env import make_env
from model.algorithm import CausalSAC
import torch
import torch.nn.functional as F
import ipdb
from dagma_main.src.dagma import linear, nonlinear

def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    """Estimate per-action-dimension causal weights on the reward with DirectLiNGAM.

    A batch is drawn from ``memory``. The columns ``[state | action | reward]``
    are stacked into one table and a DirectLiNGAM model is fitted on it. The
    fitted direct effects of each action dimension on the reward are turned
    into weights by a softmax. They are then rescaled so they sum to the
    number of action dimensions (equal coefficients give all ones).

    Args:
        env: Unused; kept for a uniform call signature.
        memory: Replay buffer whose ``sample(n)`` returns
            ``(states, actions, rewards, next_states, dones)``; ``states`` and
            ``actions`` are 2-D arrays with one row per transition.
        agent: Unused; kept for a uniform call signature.
        sample_size: Number of transitions to sample.
        causal_method: Causal-discovery method; only ``'DirectLiNGAM'`` is
            supported.

    Returns:
        ``(weight, running_time)``: ``weight`` is a 1-D ``np.ndarray`` with one
        entry per action dimension. ``running_time`` is the model
        construction + fit time in seconds.

    Raises:
        ValueError: If ``causal_method`` is not supported.
    """
    # Validate up front so an unknown method fails with a clear message
    # instead of an undefined-variable error after sampling.
    if causal_method != 'DirectLiNGAM':
        raise ValueError(f"Unsupported causal_method: {causal_method!r}")

    states, actions, rewards, next_states, dones = memory.sample(sample_size)
    states = states[:sample_size, :]
    actions = actions[:sample_size, :]
    # Flatten then make a column. Using -1 rather than sample_size keeps this
    # working when the buffer returns fewer rows than requested.
    rewards = np.reshape(np.squeeze(rewards[:sample_size]), (-1, 1))

    n_state = states.shape[1]
    n_action = actions.shape[1]
    data = np.hstack((states, actions, rewards))
    X = pd.DataFrame(data, columns=list(range(data.shape[1])))

    start_time = time.perf_counter()
    model = lingam.DirectLiNGAM()
    model.fit(X)
    running_time = time.perf_counter() - start_time

    # In the LiNGAM model x = B x + e, B[i, j] is the direct effect of x_j on
    # x_i. The last row is therefore the reward equation; keep its action
    # columns.
    weight_r = model.adjacency_matrix_[-1, n_state:n_state + n_action]

    # Softmax over action dimensions, then scale by the action count so the
    # weights sum to n_action.
    weight = F.softmax(torch.Tensor(weight_r), 0).numpy()
    weight = weight * weight.shape[0]
    return weight, running_time

def get_sa2Q_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    """Estimate per-action-dimension causal weights on Q-values with DirectLiNGAM.

    ``agent.get_Q_value(memory, sample_size)`` supplies actions and Q-values.
    The columns ``[action | Q]`` are stacked and a DirectLiNGAM model is
    fitted on them. The direct effects of each action dimension on the last
    Q column are softmax-normalised and rescaled to sum to the number of
    action dimensions.

    Args:
        env: Unused; kept for a uniform call signature.
        memory: Replay buffer forwarded to ``agent.get_Q_value``.
        agent: Object whose ``get_Q_value(memory, n)`` returns
            ``(states, actions, qvalues)``. ``qvalues`` must support
            ``.detach().cpu().numpy()`` (e.g. a torch tensor).
        sample_size: Number of rows to use.
        causal_method: Causal-discovery method; only ``'DirectLiNGAM'`` is
            supported.

    Returns:
        ``(weight, running_time)``: ``weight`` is a 1-D ``np.ndarray`` with one
        entry per action dimension. ``running_time`` is the model
        construction + fit time in seconds.

    Raises:
        ValueError: If ``causal_method`` is not supported.
    """
    # Validate up front so an unknown method fails with a clear message
    # instead of an undefined-variable error after the Q evaluation.
    if causal_method != 'DirectLiNGAM':
        raise ValueError(f"Unsupported causal_method: {causal_method!r}")

    _, actions, qvalues = agent.get_Q_value(memory, sample_size)
    actions = actions[:sample_size, :]
    # Slice Q the same way as the actions so hstack sees matching row counts,
    # and make it 2-D so a flat (N,) output stacks as a single column.
    qvalues = qvalues.detach().cpu().numpy()[:sample_size]
    qvalues = qvalues.reshape(qvalues.shape[0], -1)

    n_action = actions.shape[1]
    data = np.hstack((actions, qvalues))
    X = pd.DataFrame(data, columns=list(range(data.shape[1])))

    start_time = time.perf_counter()
    model = lingam.DirectLiNGAM()
    model.fit(X)
    running_time = time.perf_counter() - start_time

    # In the LiNGAM model x = B x + e, B[i, j] is the direct effect of x_j on
    # x_i. The last row is the (last) Q equation; keep its action columns.
    weight_r = model.adjacency_matrix_[-1, :n_action]

    # Softmax over action dimensions, then scale by the action count so the
    # weights sum to n_action.
    weight = F.softmax(torch.Tensor(weight_r), 0).numpy()
    weight = weight * weight.shape[0]
    return weight, running_time


def get_sa2r_opti_weight(env, memory, agent, ini_adjacency_matrix, sample_size=50000, causal_method='DagmaNonlinear'):
    """Estimate per-action-dimension causal weights on the reward with DAGMA.

    A batch is drawn from ``memory`` and ``[state | action | reward]`` are
    stacked into one matrix. A linear or MLP-based DAGMA model is fitted on
    it. The entries of the learned matrix linking the reward row to the
    action columns are softmax-normalised and rescaled to sum to the number
    of action dimensions.

    Args:
        env: Unused; kept for a uniform call signature.
        memory: Replay buffer whose ``sample(n)`` returns
            ``(states, actions, rewards, next_states, dones)``.
        agent: Unused; kept for a uniform call signature.
        ini_adjacency_matrix: Initial adjacency matrix passed to the linear
            DAGMA fit. ``None`` or an empty container (e.g. ``[]``) requests a
            random ``np.random.rand(d, d)`` initialisation, with
            ``d = n_state + n_action + 1``. Not used by ``'DagmaNonlinear'``.
        sample_size: Number of transitions to sample.
        causal_method: ``'DagmaLinear'`` or ``'DagmaNonlinear'``.

    Returns:
        ``(weight, running_time, W_est)``: ``weight`` has one entry per action
        dimension. ``running_time`` is the model construction + fit time in
        seconds. ``W_est`` is the learned weighted adjacency matrix.

    Raises:
        ValueError: If ``causal_method`` is not supported.
    """
    # Validate up front so an unknown method fails with a clear message
    # instead of an undefined-variable error after sampling.
    if causal_method not in ('DagmaLinear', 'DagmaNonlinear'):
        raise ValueError(f"Unsupported causal_method: {causal_method!r}")

    states, actions, rewards, next_states, dones = memory.sample(sample_size)
    states = states[:sample_size, :]
    actions = actions[:sample_size, :]
    # Flatten then make a column. Using -1 rather than sample_size keeps this
    # working when the buffer returns fewer rows than requested.
    rewards = np.reshape(np.squeeze(rewards[:sample_size]), (-1, 1))

    n_state = states.shape[1]
    n_action = actions.shape[1]
    n_vars = n_state + n_action + rewards.shape[1]

    # Comparing an ndarray with `== []` is elementwise, so its truth value is
    # ambiguous (or the comparison errors). Test for "missing" explicitly.
    if ini_adjacency_matrix is None or np.size(ini_adjacency_matrix) == 0:
        ini_adjacency_matrix = np.random.rand(n_vars, n_vars)
    X = np.hstack((states, actions, rewards))

    start_time = time.perf_counter()
    if causal_method == 'DagmaLinear':
        model = linear.DagmaLinear(loss_type='l2')
        fitted = model.fit(X, ini_adjacency_matrix, lambda1=0.02)
        # Accept either a bare matrix or an object exposing it as `.W_est`.
        W_est = getattr(fitted, 'W_est', fitted)
    else:  # 'DagmaNonlinear'
        dims = [n_vars, 10, 1]
        eq_model = nonlinear.DagmaMLP(dims=dims, bias=True, dtype=torch.float)
        model = nonlinear.DagmaNonlinear(eq_model, dtype=torch.float)
        W_est = model.fit(X, lambda1=0.02, lambda2=0.005, warm_iter=5e4, max_iter=8e4)  #* 5e4, 8e4
    running_time = time.perf_counter() - start_time

    # NOTE(review): in upstream DAGMA, W[i, j] != 0 encodes an edge i -> j, so
    # row -1 would select reward -> action edges rather than action -> reward.
    # Confirm the convention of this vendored fork. The action -> reward
    # entries may instead be W_est[n_state:n_state + n_action, -1].
    weight = W_est[-1, n_state:n_state + n_action]

    # Softmax over action dimensions, then scale by the action count so the
    # weights sum to n_action.
    weight = F.softmax(torch.Tensor(weight), 0).numpy()
    weight = weight * weight.shape[0]
    return weight, running_time, W_est